# -*- coding: utf-8 -*-
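"""
Feature-level analysis of a CNN: per-feature sensitivity to perturbations,
average feature energy, and mutual information (MI) between individual
features and the output label, plus helpers to plot these quantities.

Created on Fri May  9 15:05:21 2025

@author: User
"""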

# -*- coding: utf-8 -*-
"""
Created on Thu Jan  9 12:55:40 2025

@author: User
"""
import torch
from torch.utils import data
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import torchvision.models as models

from copy import deepcopy
import torch.nn as nn
from torch.nn import functional as F
import torch.nn.functional as F
import numpy as np
import sys, os
import torchvision.transforms.functional as TF

import matplotlib.pyplot as plt
import pickle
import random
from PIL import Image
import scipy
import time 
import gc


import os
dir_path = os.path.dirname(os.path.realpath(__file__))
os.chdir(dir_path)

from lib_train import *
from lib_cnn import * 
from all_estimators import * 

# Module-level mutual-information estimators shared by the functions below.
# MI_Estimator comes from the star-imports above (not visible in this file).

# First entry of every MI_Estimator argument list.
# NOTE(review): presumably the k of the k-nearest-neighbour KSG estimator -- confirm in all_estimators.
k1=3
# Second entry of the argument list for the "local" and "mixed" estimators.
# NOTE(review): meaning of these constants is not visible here -- confirm in all_estimators.
c_local = [1.0]
# c_global = [0.8,0.9,1.0,1.1,1.2]
# c_global = [1.0]
# print(C_z)
# Bound estimator methods; only KSG_mixed is called in this file (by MI_and_energy).
KSG_est = MI_Estimator([k1]).KSG
KSG_local_est = MI_Estimator([k1,c_local]).KSG_local
KSG_mixed = MI_Estimator([k1,c_local]).Mixed_KSG 



def extract_features(model, inputs, layer_index):
    """Run ``inputs`` through the leading part of ``model.features``.

    Args:
        model: object exposing a ``features`` module.
        inputs: tensor fed to the first feature layer.
        layer_index: the string ``'last'`` to apply the whole
            ``model.features`` stack, or an integer ``n`` to apply only the
            first ``n`` child modules of ``model.features``.

    Returns:
        The activations produced by the selected layers.
    """
    if layer_index != 'last':
        # Keep only the children that come before the cut point.
        leading_layers = list(model.features.children())[:layer_index]
        truncated_stack = torch.nn.Sequential(*leading_layers)
        return truncated_stack(inputs)

    full_stack = torch.nn.Sequential(model.features)
    return full_stack(inputs)

def forward_from_layer(model, features, layer_index):
    """Finish the forward pass starting from intermediate ``features``.

    Args:
        model: object exposing ``features`` and ``classifier`` modules.
        features: activations taken at the cut point ``layer_index``; the
            first axis is treated as the batch axis.
        layer_index: ``'last'`` to feed ``features`` straight into
            ``model.classifier``, or an integer ``n`` to first apply the
            children of ``model.features`` from position ``n`` onward and then
            the classifier.

    Returns:
        The output of the remaining layers. The batch axis is always kept,
        including for a batch of one sample.
    """
    if layer_index == 'last':
        remaining_model = torch.nn.Sequential(model.classifier)
    else:
        # Feature layers from the cut point onward, followed by the classifier.
        trailing_layers = torch.nn.Sequential(*list(model.features.children())[layer_index:])
        remaining_model = torch.nn.Sequential(trailing_layers, model.classifier)

    # Singleton axes are dropped before the remaining layers run. A bare
    # squeeze() would also drop the batch axis when the batch holds a single
    # sample, so that axis is restored explicitly.
    squeezed = features.squeeze()
    if features.dim() > 0 and features.shape[0] == 1:
        squeezed = squeezed.unsqueeze(0)

    return remaining_model(squeezed)

def compute_sensitivity(model, testloader, layer_name, epsilon=0.1):
    """Measure how much the class probabilities move when one feature is perturbed.

    For every batch, the activations at ``layer_name`` are extracted, and each
    feature (index along axis 1) is shifted in turn by ``epsilon`` times a
    single standard-normal draw shared by the whole batch. The sensitivity of
    a feature is the mean absolute difference between the softmax outputs with
    and without the shift. Per-batch results are averaged with equal weight
    per batch.

    Args:
        model: network exposing ``features`` and ``classifier``; put in eval
            mode by this function.
        testloader: iterable yielding ``(inputs, labels, extra)`` triples; only
            ``inputs`` is used.
        layer_name: cut point passed to ``extract_features`` and
            ``forward_from_layer`` (``'last'`` or an integer index).
        epsilon: scale of the perturbation.

    Returns:
        Tuple ``(sensitivities, avg_energies)`` of numpy arrays: the per-feature
        sensitivity and the per-feature mean squared activation.
    """
    model.eval()
    sensitivities = []
    avg_energies = []

    with torch.no_grad():
        for inputs, _labels, _ in testloader:
            # Activations at the chosen cut point.
            features = extract_features(model, inputs, layer_name)
            num_features = features.shape[1]

            # Mean squared activation over the batch axis.
            avg_energy = torch.mean(features**2, dim=0).cpu().numpy()

            # The unperturbed prediction does not depend on which feature is
            # being perturbed, so it is computed once per batch instead of
            # once per feature.
            original_outputs = forward_from_layer(model, features, layer_name)
            original_probs = F.softmax(original_outputs, dim=1)

            sensitivity = np.zeros(num_features)
            for i in range(num_features):
                perturbed_features = features.clone()
                # One scalar draw: every sample in the batch gets the same shift.
                perturbed_features[:, i] += epsilon*np.random.randn()

                perturbed_outputs = forward_from_layer(model, perturbed_features, layer_name)
                perturbed_probs = F.softmax(perturbed_outputs, dim=1)

                # Mean absolute change in predicted probability.
                sensitivity[i] = torch.mean(torch.abs(perturbed_probs - original_probs)).item()

            sensitivities.append(sensitivity)
            avg_energies.append(avg_energy)

    # Average the per-batch results.
    sensitivities = np.mean(np.array(sensitivities), axis=0)
    avg_energies = np.mean(np.array(avg_energies), axis=0)

    return sensitivities, avg_energies


def compute_sensitivity_label_changes(model, testloader, layer_name, epsilon=1000):
    """Measure how often the predicted label flips when one feature is perturbed.

    For every batch, the activations at ``layer_name`` are extracted, and each
    feature (index along axis 1) receives independent Gaussian noise of scale
    ``epsilon`` for every sample. The sensitivity of a feature is the fraction
    of samples whose arg-max class differs from the unperturbed prediction.
    Per-batch results are averaged with equal weight per batch.

    Args:
        model: network exposing ``features`` and ``classifier``; put in eval
            mode by this function.
        testloader: iterable yielding ``(inputs, labels, extra)`` triples; only
            ``inputs`` is used.
        layer_name: cut point passed to ``extract_features`` and
            ``forward_from_layer`` (``'last'`` or an integer index).
        epsilon: standard deviation of the added noise (the default of 1000 is
            a large perturbation, not a small one).

    Returns:
        Tuple ``(sensitivities, avg_energies)`` of numpy arrays: the per-feature
        label-flip rate and the per-feature mean squared activation.
    """
    model.eval()
    sensitivities = []
    avg_energies = []

    with torch.no_grad():
        for inputs, _labels, _ in testloader:
            # Activations at the chosen cut point.
            features = extract_features(model, inputs, layer_name)
            num_features = features.shape[1]

            # Mean squared activation over the batch axis.
            avg_energy = torch.mean(features**2, dim=0).cpu().numpy()

            # The unperturbed labels are the same for every feature, so they
            # are computed once per batch instead of once per feature.
            original_outputs = forward_from_layer(model, features, layer_name)
            original_probs = F.softmax(original_outputs, dim=1)
            original_labels = torch.argmax(original_probs, dim=1)

            sensitivity = np.zeros(num_features)
            for i in range(num_features):
                perturbed_features = features.clone()
                # Independent noise per sample for feature i.
                perturbed_features[:, i] += epsilon * torch.randn_like(features[:, i])

                perturbed_outputs = forward_from_layer(model, perturbed_features, layer_name)
                perturbed_probs = F.softmax(perturbed_outputs, dim=1)
                perturbed_labels = torch.argmax(perturbed_probs, dim=1)

                # Fraction of samples whose predicted label changed.
                sensitivity[i] = torch.mean((perturbed_labels != original_labels).float()).item()

            sensitivities.append(sensitivity)
            avg_energies.append(avg_energy)

    # Average the per-batch results.
    sensitivities = np.mean(np.array(sensitivities), axis=0)
    avg_energies = np.mean(np.array(avg_energies), axis=0)

    return sensitivities, avg_energies


def MI_and_energy(model, testloader, layer_name, max_samples=10000):
    """Estimate the MI between each individual feature and the labels.

    Activations at ``layer_name`` are collected over the whole loader, and the
    module-level ``KSG_mixed`` estimator is applied to every feature (index
    along axis 1 of the squeezed activation matrix) against the labels.

    Args:
        model: network exposing ``features``; put in eval mode by this function.
        testloader: iterable yielding ``(inputs, labels, extra)`` triples; the
            third element is ignored.
        layer_name: cut point passed to ``extract_features`` (``'last'`` or an
            integer index).
        max_samples: only the first ``max_samples`` collected samples are
            passed to the estimator; ``None`` uses all of them. The default
            keeps the previously hard-coded cap of 10000.

    Returns:
        Tuple ``(MI_values, avg_energies)``: a numpy array with one MI estimate
        per feature, and the mean squared activation over all collected
        samples (computed before squeezing, so it keeps any singleton axes of
        the activations).
    """
    model.eval()

    feature_list = []
    label_list = []

    with torch.no_grad():
        for inputs, labels, _ in testloader:
            features = extract_features(model, inputs, layer_name)
            feature_list.append(features.cpu().numpy())
            label_list.append(labels.cpu().numpy())

    # Stack the batches along the sample axis.
    feature_matrix = np.concatenate(feature_list, axis=0)
    labels = np.concatenate(label_list, axis=0)

    # Energy uses every collected sample, not just the MI subset.
    avg_energies = np.mean(feature_matrix**2, axis=0)
    # NOTE(review): squeeze() also drops the sample axis if only one sample was
    # collected -- assumes the loader yields more than one sample.
    feature_matrix = feature_matrix.squeeze()

    # Cap the number of samples handed to the estimator.
    subset = slice(None, max_samples)
    subset_labels = labels[subset]
    MI_values = np.array([
        KSG_mixed(feature_matrix[subset, i], subset_labels)
        for i in range(feature_matrix.shape[1])
    ])

    return MI_values, avg_energies


def plot_sensitivity_vs_energy(sensitivities, avg_energies):
    """Show a scatter plot of per-feature sensitivity against average energy.

    Both axes are linear. The y-range is set to 0.9x the smallest and 1.1x the
    largest sensitivity, and the figure is displayed with ``plt.show()``.

    Args:
        sensitivities: per-feature sensitivity values (y-axis).
        avg_energies: per-feature average energies (x-axis).
    """
    plt.scatter(avg_energies, sensitivities, alpha=0.7)

    # Labelling.
    plt.title("Feature Sensitivity vs. Energy")
    plt.xlabel("Average Energy of Feature")
    plt.ylabel("Label Sensitivity")

    # Dashed grid on major and minor ticks.
    plt.grid(True, which="both", linestyle="--", linewidth=0.5)

    # Widen the y-range slightly beyond the data.
    lower_bound = min(sensitivities) * 0.9
    upper_bound = max(sensitivities) * 1.1
    plt.ylim(lower_bound, upper_bound)

    plt.show()


def plot_MI_vs_energy(MI_values, avg_energies):
    """Plot per-feature MI against average feature energy in two panels.

    Left panel: scatter of each feature's MI against its average energy.
    Right panel: for every energy value taken as a threshold, the mean MI of
    all features whose energy is less than or equal to that threshold.

    The figure is created but not shown; call ``plt.show()`` or save it
    afterwards. As a side effect the global matplotlib font family is set to
    "Times New Roman".

    Args:
        MI_values: one MI estimate per feature (array-like).
        avg_energies: one average energy per feature (array-like); singleton
            axes are squeezed away.
    """
    plt.rcParams.update({"font.family": "Times New Roman"})

    # Arrays are required for the boolean-mask indexing below; accepting
    # array-likes lets plain lists work as well.
    MI_values = np.asarray(MI_values)
    avg_energies = np.asarray(avg_energies).squeeze()
    fig, axes = plt.subplots(1, 2, figsize=(12, 5))

    # Left plot: scatter of MI vs. feature energy.
    axes[0].scatter(avg_energies, MI_values, alpha=0.7)
    axes[0].set_xlabel("Average Energy of Individual Features",fontsize=16)
    axes[0].set_ylabel("MI of Individual Feature \n and Output Label",fontsize=16)
    axes[0].set_title("Feature MI vs. Energy",fontsize=16)
    axes[0].grid(True, which="both", linestyle="--", linewidth=0.5)

    # Expand the y-range by 10% on each side. Multiplying a negative bound by
    # 0.9 (or 1.1 for the top) would move it *into* the data and clip points,
    # so the factor is flipped for negative values.
    mi_min, mi_max = min(MI_values), max(MI_values)
    y_low = mi_min * 0.9 if mi_min >= 0 else mi_min * 1.1
    y_high = mi_max * 1.1 if mi_max >= 0 else mi_max * 0.9
    axes[0].set_ylim(y_low, y_high)

    # Right plot: average MI of the features at or below each energy threshold.
    sorted_energy = np.sort(avg_energies)
    avg_MI = [np.mean(MI_values[avg_energies <= e]) for e in sorted_energy]

    axes[1].plot(sorted_energy, avg_MI, marker="o", linestyle="-", alpha=0.7)
    axes[1].set_xlabel("Threshold Energy",fontsize=16)
    axes[1].set_ylabel("Average MI of Features \n ≤ Energy Threshold",fontsize=16)
    axes[1].set_title("Cumulative MI vs. Energy Threshold",fontsize=16)
    axes[1].grid(True, linestyle="--", linewidth=0.5)

    plt.tight_layout()


def genloaders_vision(loader_params):
    """Download a torchvision dataset and build the project's data loaders.

    The raw ``.data`` arrays are scaled to floats by dividing by 255 and given
    a channel axis at position 1 (inserted for MNIST / FashionMNIST, moved
    from the last position for CIFAR10 / CIFAR100). CIFAR targets are
    converted to tensors. The resulting tensors are moved to the GPU when one
    is available (CPU otherwise) and handed to ``genloaders``.

    Both the train and the test split are stored under
    ``loader_params['root_folder']``.

    Args:
        loader_params: dict with at least ``'dataset_name'`` (one of
            ``'FashionMNIST'``, ``'MNIST'``, ``'CIFAR10'``, ``'CIFAR100'``) and
            ``'root_folder'``; the whole dict is also forwarded to
            ``genloaders``.

    Returns:
        Tuple ``(trainloader, testloader, IG_trainloader)`` as returned by
        ``genloaders``.

    Raises:
        ValueError: if ``'dataset_name'`` is not one of the supported names.
    """
    dataset_classes = {
        'FashionMNIST': torchvision.datasets.FashionMNIST,
        'MNIST': torchvision.datasets.MNIST,
        'CIFAR10': torchvision.datasets.CIFAR10,
        'CIFAR100': torchvision.datasets.CIFAR100,
    }
    name = loader_params['dataset_name']
    if name not in dataset_classes:
        raise ValueError(
            f"Unsupported dataset_name {name!r}; expected one of {sorted(dataset_classes)}"
        )
    dataset_cls = dataset_classes[name]
    root = loader_params['root_folder']

    # The transform is attached to the dataset objects, but the loaders below
    # are built from the raw ``.data`` tensors, so the scaling is done by hand.
    to_tensor = transforms.Compose([transforms.ToTensor()])
    dataset = dataset_cls(root=root, train=True, download=True, transform=to_tensor)
    dataset_test = dataset_cls(root=root, train=False, download=True, transform=to_tensor)

    is_cifar = name in ('CIFAR10', 'CIFAR100')
    for split in (dataset, dataset_test):
        if is_cifar:
            # CIFAR data is a numpy array with the channel axis last and
            # targets stored as a plain sequence.
            images = torch.from_numpy(split.data).float() / 255.0
            split.data = torch.permute(images, (0, 3, 1, 2))
            split.targets = torch.from_numpy(np.array(split.targets))
        else:
            # (Fashion)MNIST data has no channel axis; add one.
            split.data = (split.data.float() / 255.0).unsqueeze(1)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    trainloader, testloader, IG_trainloader = genloaders(
        dataset.data.to(device), dataset.targets.to(device),
        dataset_test.data.to(device), dataset_test.targets.to(device),
        loader_params)

    return trainloader, testloader, IG_trainloader

def prerequisites():
    """Seed all random number generators and prepare the compute device.

    Seeds Python's ``random``, NumPy and PyTorch (CPU and CUDA) with 0 and
    requests deterministic cuDNN kernels. When a CUDA device is available,
    new tensors default to CUDA float tensors and the CUDA cache is emptied;
    on a CPU-only machine these CUDA-specific steps are skipped so the
    function does not fail.
    """
    random.seed(0)
    np.random.seed(0)
    torch.manual_seed(0)
    # Documented as silently ignored when CUDA is not available.
    torch.cuda.manual_seed(0)
    torch.backends.cudnn.deterministic = True

    cuda_available = torch.cuda.is_available()
    if cuda_available:
        # NOTE(review): deprecated in recent PyTorch releases in favour of
        # torch.set_default_device / torch.set_default_dtype.
        torch.set_default_tensor_type('torch.cuda.FloatTensor')

    # Release unreferenced Python objects, then cached GPU memory.
    gc.collect()
    if cuda_available:
        torch.cuda.empty_cache()
    

if __name__ == "__main__":
    # Script entry point: build the data loaders and the model, train or load
    # the weights, then plot per-feature MI against feature energy.

    # Seed the RNGs and configure the default device before anything is built.
    prerequisites()
    program_mode = 'normal' # normal or GT (Ground truth); not read anywhere in this script
    save_mode = 'store' # store (train + save), load (read saved weights) or none (use the model as constructed)

    # Forwarded to genloaders_vision / genloaders.
    loader_params = {
        'dataset_name': 'CIFAR10',
        'conversion': 'none',
        'root_folder': '../data',
        'training_size': 10000, # 'full'
        'batch_size': 200,
        'IG_batch_size': 400, 
        'transform': None,
        'add_singleton': False,
        'convert_to_torch': False,
        }

    # Forwarded to train_model_general.
    train_params = {
        'optimizer': 'Adam',
        'scheduler': {'name': None}, # 'step_size': 10, 'milestones':[10,20,30],'gamma':0.8, 'max_lr': 0.01}
        'init_rate': 0.0005,
        'total_epochs': 100,
        'weight_decay': 0, 
        'criterion': 'CrossEntropyLoss',
        'disp_epoch': False,
        'disp_loss_epoch': True,
        'disp_time_per_epoch': True, 
        'disp_loss_final': False, 
        'disp_accuracy_final': True
        }

    # Model class and constructor arguments.
    model_params={
        'type': CNN,
        'name': 'ShallowfatterCIFAR10',
        'in_channels': 3,
        'batchnorm': False,
        }

    model = model_params['type'](model_params['name'], model_params['in_channels'], batchnorm = model_params['batchnorm'])

    trainloader, testloader, IG_trainloader = genloaders_vision(loader_params)

    if save_mode == 'store':
        trained_model, all_losses = train_model_general(model,trainloader, train_params)
        torch.save(trained_model.state_dict(), "trained_model_cifar.pth")
        print("Model saved successfully!")
        # Analyse the network that was actually trained. This is a no-op if
        # train_model_general trains in place and returns the same object, and
        # it avoids analysing untrained weights if it returns a copy.
        model = trained_model
    elif save_mode == 'load':
        model.load_state_dict(torch.load("trained_model_cifar.pth"))
        model.eval()
        print("Model loaded successfully!")

    # 'last' analyses the output of the full feature stack.
    layer_index = 'last'
    # Alternative analysis:
    # sensitivities, avg_energies = compute_sensitivity_label_changes(model, trainloader, layer_index)
    MI, avg_energies = MI_and_energy(model, testloader, layer_index)

    plot_MI_vs_energy(MI, avg_energies)
    # plot_MI_vs_energy only builds the figure; display it so a plain script
    # run does not exit without showing the result.
    plt.show()

    
    
